load("@bazel_skylib//lib:paths.bzl", "paths")
load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
load("defs.bzl", "get_fbgemm_avx2_srcs", "get_fbgemm_avx512_srcs", "get_fbgemm_base_srcs", "get_fbgemm_generic_srcs", "get_fbgemm_public_headers", "get_fbgemm_tests")

# Base FBGEMM library: the sources returned by get_fbgemm_base_srcs() plus the
# headers that live directly under src/.
cc_library(
    name = "fbgemm_base",
    srcs = get_fbgemm_base_srcs() + glob(["src/*.h"]),
    # Puts "src" on the include search path of this target and its dependents.
    includes = ["src"],
    linkstatic = True,
    deps = [
        ":fbgemm_headers",
        "@asmjit",
        "@cpuinfo",
    ],
)

# Public FBGEMM target: the sources returned by get_fbgemm_generic_srcs(),
# linked together with the base, AVX2 and AVX-512 libraries.
cc_library(
    name = "fbgemm",
    srcs = get_fbgemm_generic_srcs() + glob(["src/*.h"]),
    includes = ["src"],
    linkstatic = True,
    visibility = ["//visibility:public"],
    deps = [
        ":fbgemm_avx2",
        ":fbgemm_avx512",
        ":fbgemm_base",
    ],
)

# Sources returned by get_fbgemm_avx2_srcs(), built in their own target so the
# AVX2/FMA/F16C compiler flags apply only to these files.
cc_library(
    name = "fbgemm_avx2",
    srcs = get_fbgemm_avx2_srcs() + glob(["src/*.h"]),
    copts = [
        "-m64",
        # Instruction-set extensions enabled for this target.
        "-mavx2",
        "-mfma",
        "-mf16c",
        # Assembly dialect passed to the compiler.
        "-masm=intel",
    ],
    linkstatic = True,
    deps = [":fbgemm_base"],
)

# Sources returned by get_fbgemm_avx512_srcs(), built in their own target so
# the AVX-512 compiler flags apply only to these files.
cc_library(
    name = "fbgemm_avx512",
    srcs = get_fbgemm_avx512_srcs() + glob(["src/*.h"]),
    copts = [
        "-m64",
        "-mfma",
        # AVX-512 feature subsets enabled for this target.
        "-mavx512f",
        "-mavx512bw",
        "-mavx512dq",
        "-mavx512vl",
        # Assembly dialect passed to the compiler.
        "-masm=intel",
    ],
    linkstatic = True,
    deps = [":fbgemm_base"],
)

# Header-only target exposing the headers returned by
# get_fbgemm_public_headers(), with "include" on the include search path.
cc_library(
    name = "fbgemm_headers",
    hdrs = get_fbgemm_public_headers(),
    includes = ["include"],
    visibility = ["//visibility:public"],
)

# Exposes src/RefImplementations.h under the "fbgemm" include prefix. This
# header is included from
# pytorch/caffe2/quantization/server/conv_dnnlowp_op.cc.
cc_library(
    name = "fbgemm_src_headers",
    hdrs = ["src/RefImplementations.h"],
    include_prefix = "fbgemm",
    visibility = ["//visibility:public"],
)


# Helper library for the tests: utility sources from bench/ and test/, with
# every header found directly in those two directories.
cc_library(
    name = "test_utils",
    srcs = [
        "bench/BenchUtils.cc",
        "test/EmbeddingSpMDMTestUtils.cc",
        "test/QuantizationHelpers.cc",
        "test/TestUtils.cc",
    ],
    hdrs = glob([
        "test/*.h",
        "bench/*.h",
    ]),
    includes = [
        "bench",
        "test",
    ],
    # Links against librt.
    linkopts = ["-lrt"],
    linkstatic = True,
    deps = [
        ":fbgemm",
        "@com_google_googletest//:gtest_main",
    ],
)

# Declares one cc_test per file returned by get_fbgemm_tests(). The target
# name is the file's basename with its extension stripped.
[
    cc_test(
        name = paths.split_extension(paths.basename(test_src))[0],
        size = "medium",
        srcs = [test_src],
        deps = [":test_utils"],
    )
    for test_src in get_fbgemm_tests()
]
